{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Agent Demo\n",
    "\n",
    "Simple function calling using the **langchain4j** framework\n",
    "\n",
    "![diagram](src/site/resources/ReACT.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Notebook-wide settings referenced by the kernel magic cells below.\n",
    "// Home directory of the user running the kernel.\n",
    "var userHomeDir = System.getProperty(\"user.home\");\n",
    "// file:// URL of the local Maven repository under the home directory; registered as a repository in the dependency cell.\n",
    "var localRespoUrl = \"file://\" + userHomeDir + \"/.m2/repository/\";\n",
    "// Version used for the langchain4j and langchain4j-open-ai artifacts.\n",
    "var langchain4jVersion = \"1.0.1\";\n",
    "// Version used for the Tavily web-search artifact.\n",
    "var langchain4jbeta = \"1.0.1-beta6\";\n",
    "// NOTE(review): not referenced by any dependency line in this notebook - confirm whether it is still needed.\n",
    "var langgraph4jVersion = \"1.6-SNAPSHOT\";"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remove the installed langgraph4j packages from the Jupyter kernel cache"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash \n",
    "rm -rf \\{userHomeDir}/Library/Jupyter/kernels/rapaio-jupyter-kernel/mima_cache/org/bsc/langgraph4j"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add the local Maven repository and install the required Maven dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%dependency /add-repo local \\{localRespoUrl} release|never snapshot|always\n",
    "// %dependency /list-repos\n",
    "%dependency /add org.slf4j:slf4j-jdk14:2.0.9\n",
    "%dependency /add dev.langchain4j:langchain4j:\\{langchain4jVersion}\n",
    "%dependency /add dev.langchain4j:langchain4j-open-ai:\\{langchain4jVersion}\n",
    "%dependency /add dev.langchain4j:langchain4j-web-search-engine-tavily:\\{langchain4jbeta}\n",
    "\n",
    "%dependency /resolve"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Initialize Logger**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Load the java.util.logging configuration from ./logging.properties (resolved against the\n",
    "// kernel's working directory); try-with-resources closes the stream afterwards.\n",
    "try( var file = new java.io.FileInputStream(\"./logging.properties\")) {\n",
    "    java.util.logging.LogManager.getLogManager().readConfiguration( file );\n",
    "}\n",
    "\n",
    "// Shared SLF4J logger used by the tool classes and the cells below.\n",
    "// NOTE(review): the logger name looks carried over from another notebook - confirm it is intended.\n",
    "var log = org.slf4j.LoggerFactory.getLogger(\"AdaptiveRag\");\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Tools \n",
    "\n",
    "Define the tools that the agent will use to perform actions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dev.langchain4j.agent.tool.P;\n",
    "import dev.langchain4j.agent.tool.Tool;\n",
    "\n",
    "import java.util.Optional;\n",
    "\n",
    "import static java.lang.String.format;\n",
    "\n",
    "/**\n",
    " * Minimal tool used to check that the model can trigger a function call.\n",
    " * Remembers the text produced by the most recent invocation.\n",
    " */\n",
    "public class TestTool {\n",
    "    // Text returned by the latest execTest call; null until the tool has run.\n",
    "    private String lastResult;\n",
    "\n",
    "    /** Returns the most recent tool output, or empty if execTest has not been called yet. */\n",
    "    Optional<String> lastResult() {\n",
    "        return Optional.ofNullable(this.lastResult);\n",
    "    }\n",
    "\n",
    "    /**\n",
    "     * Builds the result text from the given message, stores it as the last result,\n",
    "     * logs it and returns it.\n",
    "     */\n",
    "    @Tool(\"tool for test AI agent executor\")\n",
    "    String execTest(@P(\"test message\") String message) {\n",
    "        final String outcome = String.format(\"test tool executed: %s\", message);\n",
    "        this.lastResult = outcome;\n",
    "\n",
    "        log.info(\"exec test: {}\", outcome);\n",
    "        return outcome;\n",
    "    }\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize LLM for Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dev.langchain4j.model.openai.OpenAiChatModel;\n",
    "import java.util.List;\n",
    "import java.util.Map;\n",
    "import java.util.stream.Collectors;\n",
    "\n",
    "// Chat model shared by the agent cells below.\n",
    "// The API key is read from the OPENAI_API_KEY environment variable of the kernel process.\n",
    "// Temperature is pinned to 0.0, presumably to keep the demo output as repeatable as possible - TODO confirm.\n",
    "var model = OpenAiChatModel.builder()\n",
    "    .apiKey( System.getenv(\"OPENAI_API_KEY\") )\n",
    "    .modelName( \"gpt-4o-mini\" )\n",
    "    .logResponses(true)\n",
    "    .maxRetries(2)\n",
    "    .temperature(0.0)\n",
    "    .maxTokens(2000)\n",
    "    .build();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Invoke Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "AiMessage { text = null toolExecutionRequests = [ToolExecutionRequest { id = \"call_jJl8MsAE101i2sAbh7v8JPdz\", name = \"execTest\", arguments = \"{\"message\":\"bartolomeo\"}\" }] } \n"
     ]
    }
   ],
   "source": [
    "import dev.langchain4j.agent.tool.ToolSpecification;\n",
    "import dev.langchain4j.agent.tool.ToolSpecifications;\n",
    "import dev.langchain4j.data.message.UserMessage;\n",
    "import dev.langchain4j.data.message.SystemMessage;\n",
    "import dev.langchain4j.data.message.AiMessage;\n",
    "import dev.langchain4j.model.output.Response;\n",
    "import dev.langchain4j.model.openai.OpenAiChatModel;\n",
    "import dev.langchain4j.model.chat.request.ChatRequest;\n",
    "import dev.langchain4j.model.chat.request.ChatRequestParameters;\n",
    "\n",
    "// Tool specifications derived from the TestTool class.\n",
    "var tools = ToolSpecifications.toolSpecificationsFrom( TestTool.class );\n",
    "\n",
    "var systemMessage = SystemMessage.from(\"you are my useful assistant\");\n",
    "var userMessage = UserMessage.from(\"Hi i'm bartolomeo! please test with my name as input\");\n",
    "\n",
    "// The tool specifications are handed to the model through the request parameters.\n",
    "var params = ChatRequestParameters.builder()\n",
    "                .toolSpecifications( tools )\n",
    "                .build();\n",
    "var request = ChatRequest.builder()\n",
    "                .parameters( params )\n",
    "                .messages( systemMessage, userMessage )\n",
    "                .build();\n",
    "\n",
    "// In the saved run the model replies with a tool execution request for execTest instead of text.\n",
    "// The response variable is read again by the tool-execution cell further down.\n",
    "var response = model.chat(request );\n",
    "\n",
    "log.info(  \"{}\", response.aiMessage() );"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dev.langchain4j.agent.tool.Tool;\n",
    "import dev.langchain4j.agent.tool.ToolExecutionRequest;\n",
    "import dev.langchain4j.agent.tool.ToolSpecification;\n",
    "import dev.langchain4j.data.message.ToolExecutionResultMessage;\n",
    "import dev.langchain4j.service.tool.DefaultToolExecutor;\n",
    "import dev.langchain4j.service.tool.ToolExecutor;\n",
    "import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom;\n",
    "\n",
    "import java.lang.reflect.Method;\n",
    "import java.util.*;\n",
    "import java.util.stream.Collectors;\n",
    "\n",
    "/**\n",
    " * Registry of tool-bearing objects that dispatches tool execution requests by tool name.\n",
    " * <p>\n",
    " * Every method annotated with {@code @Tool} on the supplied objects is paired with a\n",
    " * {@code DefaultToolExecutor} bound to that object.\n",
    " */\n",
    "public class ToolNode {\n",
    "\n",
    "    /** A tool specification together with the executor that runs it. */\n",
    "    record Specification( ToolSpecification value, ToolExecutor executor ) {\n",
    "\n",
    "        /** Builds the specification and its executor from an annotated method of the given object. */\n",
    "        public Specification(Object objectWithTool, Method method ) {\n",
    "            this( toolSpecificationFrom(method), new DefaultToolExecutor(objectWithTool, method));\n",
    "        }\n",
    "    }\n",
    "\n",
    "    /**\n",
    "     * Creates a node from the {@code @Tool} methods declared directly on the class of each object.\n",
    "     * <p>\n",
    "     * The parameter is a wildcard collection so that a typed collection such as\n",
    "     * {@code List<TestTool>} binds to this overload instead of being wrapped by the\n",
    "     * varargs overload and scanned as if the list itself were a tool object.\n",
    "     *\n",
    "     * @param objectsWithTools instances whose declared methods are scanned\n",
    "     * @return a node exposing every tool found\n",
    "     * @throws NullPointerException if {@code objectsWithTools} is null\n",
    "     * @throws IllegalArgumentException if no annotated method is found\n",
    "     */\n",
    "    public static ToolNode of( Collection<?> objectsWithTools) {\n",
    "        Objects.requireNonNull( objectsWithTools, \"objectsWithTools cannot be null!\" );\n",
    "\n",
    "        List<Specification> toolSpecifications = new ArrayList<>();\n",
    "\n",
    "        for (Object objectWithTool : objectsWithTools ) {\n",
    "            for (Method method : objectWithTool.getClass().getDeclaredMethods()) {\n",
    "                if (method.isAnnotationPresent(Tool.class)) {\n",
    "                    toolSpecifications.add( new Specification( objectWithTool, method));\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "        return new ToolNode(toolSpecifications);\n",
    "    }\n",
    "\n",
    "    /** Varargs convenience that forwards to {@link #of(Collection)}. */\n",
    "    public static ToolNode of(Object ...objectsWithTools) {\n",
    "        return of( Arrays.asList(objectsWithTools) );\n",
    "    }\n",
    "\n",
    "    // Registered tools in registration order; never empty and never modified after construction.\n",
    "    private final List<Specification> entries;\n",
    "\n",
    "    private ToolNode( List<Specification> entries) {\n",
    "        if( entries.isEmpty() ) {\n",
    "            throw new IllegalArgumentException(\"entries cannot be empty!\");\n",
    "        }\n",
    "        // Unmodifiable copy so later changes to the caller's list cannot alter the registry.\n",
    "        this.entries = List.copyOf(entries);\n",
    "    }\n",
    "\n",
    "    /** Returns the specifications of all registered tools, in registration order. */\n",
    "    public List<ToolSpecification> toolSpecifications() {\n",
    "        return this.entries.stream()\n",
    "                .map(Specification::value)\n",
    "                .collect(Collectors.toList());\n",
    "    }\n",
    "\n",
    "    /**\n",
    "     * Executes a single request with the first registered tool whose name equals the requested name.\n",
    "     *\n",
    "     * @param request the tool call to run\n",
    "     * @param memoryId value forwarded unchanged to the tool executor; may be null\n",
    "     * @return the tool result message, or empty if no registered tool has that name\n",
    "     */\n",
    "    public Optional<ToolExecutionResultMessage> execute( ToolExecutionRequest request, Object memoryId ) {\n",
    "        log.trace( \"execute: {}\", request.name() );\n",
    "\n",
    "        return entries.stream()\n",
    "                .filter( v -> v.value().name().equals(request.name()))\n",
    "                .findFirst()\n",
    "                .map( e -> {\n",
    "                    String value = e.executor().execute(request, memoryId);\n",
    "                    return new ToolExecutionResultMessage( request.id(), request.name(), value );\n",
    "                });\n",
    "    }\n",
    "\n",
    "    /**\n",
    "     * Executes the requests in iteration order and returns the first result produced.\n",
    "     * Requests naming an unregistered tool are skipped; requests that follow the first\n",
    "     * one producing a result are not executed.\n",
    "     *\n",
    "     * @param requests the tool calls to try; null is treated like an empty collection\n",
    "     * @param memoryId value forwarded unchanged to the tool executor; may be null\n",
    "     * @return the first tool result, or empty if nothing was executed\n",
    "     */\n",
    "    public Optional<ToolExecutionResultMessage> execute(Collection<ToolExecutionRequest> requests, Object memoryId ) {\n",
    "        // A message without tool calls may carry no request list at all; report that as no result.\n",
    "        if( requests == null ) {\n",
    "            return Optional.empty();\n",
    "        }\n",
    "\n",
    "        for( ToolExecutionRequest request : requests ) {\n",
    "\n",
    "            Optional<ToolExecutionResultMessage> result = execute( request, memoryId );\n",
    "\n",
    "            if( result.isPresent() ) {\n",
    "                return result;\n",
    "            }\n",
    "        }\n",
    "        return Optional.empty();\n",
    "    }\n",
    "\n",
    "    /** Same as {@link #execute(ToolExecutionRequest, Object)} with a null memory id. */\n",
    "    public Optional<ToolExecutionResultMessage> execute( ToolExecutionRequest request ) {\n",
    "        return execute( request, null );\n",
    "    }\n",
    "\n",
    "    /** Same as {@link #execute(Collection, Object)} with a null memory id. */\n",
    "    public Optional<ToolExecutionResultMessage> execute( Collection<ToolExecutionRequest> requests ) {\n",
    "        return execute( requests, null );\n",
    "    }\n",
    "\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "execute: execTest \n",
      "exec test: test tool executed: bartolomeo \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "test tool executed: bartolomeo"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// Register a TestTool instance and run the tool calls requested in the response of the earlier chat cell.\n",
    "var toolNode = ToolNode.of( new TestTool() );\n",
    "\n",
    "// execute returns the first result produced; its text becomes the cell value, or null when there is none.\n",
    "toolNode.execute( response.aiMessage().toolExecutionRequests() )\n",
    "                    .map( m -> m.text() )\n",
    "                    .orElse( null );"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dev.langchain4j.web.search.WebSearchEngine;\n",
    "import dev.langchain4j.web.search.tavily.TavilyWebSearchEngine;\n",
    "import dev.langchain4j.rag.content.retriever.WebSearchContentRetriever;\n",
    "import dev.langchain4j.rag.query.Query;\n",
    "import dev.langchain4j.agent.tool.P;\n",
    "import dev.langchain4j.agent.tool.Tool;\n",
    "import dev.langchain4j.rag.content.Content;\n",
    "import java.util.Optional;\n",
    "\n",
    "import static java.lang.String.format;\n",
    "\n",
    "/**\n",
    " * Tool that lets the model look up a topic on the web through Tavily.\n",
    " */\n",
    "public class WebSearchTool {\n",
    "\n",
    "    /**\n",
    "     * Runs a web search for the given query and returns the retrieved contents (at most three results requested).\n",
    "     * A new Tavily engine and retriever are created on every call; the API key is read from the\n",
    "     * TAVILY_API_KEY environment variable (get a free key: https://app.tavily.com/sign-in).\n",
    "     */\n",
    "    @Tool(\"tool for search topics on the web\")\n",
    "    public List<Content> execQuery( @P(\"search query\") String query) {\n",
    "        final String tavilyKey = System.getenv(\"TAVILY_API_KEY\");\n",
    "\n",
    "        var engine = TavilyWebSearchEngine.builder()\n",
    "                .apiKey(tavilyKey)\n",
    "                .build();\n",
    "\n",
    "        return WebSearchContentRetriever.builder()\n",
    "                .webSearchEngine(engine)\n",
    "                .maxResults(3)\n",
    "                .build()\n",
    "                .retrieve(new Query(query));\n",
    "    }\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "AiMessage { text = null toolExecutionRequests = [ToolExecutionRequest { id = \"call_OiU6RnlHn45mRPq11reHowwB\", name = \"execQuery\", arguments = \"{\"query\":\"100m competition winner Olympic 2024 Paris\"}\" }] } \n"
     ]
    }
   ],
   "source": [
    "import dev.langchain4j.agent.tool.ToolSpecification;\n",
    "import dev.langchain4j.agent.tool.ToolSpecifications;\n",
    "import dev.langchain4j.data.message.UserMessage;\n",
    "import dev.langchain4j.data.message.SystemMessage;\n",
    "import dev.langchain4j.data.message.AiMessage;\n",
    "import dev.langchain4j.model.output.Response;\n",
    "import dev.langchain4j.model.openai.OpenAiChatModel;\n",
    "\n",
    "// Rebuild the node so the model is offered both the test tool and the web-search tool.\n",
    "var toolNode = ToolNode.of( new TestTool(), new WebSearchTool() );\n",
    "\n",
    "var systemMessage = SystemMessage.from(\"you are my useful assistant\");\n",
    "var userMessage = UserMessage.from(\"Who won 100m competition in Olympic 2024 in Paris ?\");\n",
    "\n",
    "// ChatRequestParameters and ChatRequest are not imported in this cell; they rely on the imports of the earlier chat cell.\n",
    "var params = ChatRequestParameters.builder()\n",
    "                .toolSpecifications( toolNode.toolSpecifications() )\n",
    "                .build();\n",
    "var request = ChatRequest.builder()\n",
    "                .parameters( params )\n",
    "                .messages( systemMessage, userMessage )\n",
    "                .build();\n",
    "\n",
    "// In the saved run the model answers with a tool execution request for execQuery.\n",
    "// This cell only logs that request; the tool itself is not executed here.\n",
    "var response = model.chat(request );\n",
    "\n",
    "log.info(  \"{}\", response.aiMessage() );\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java (rjk 2.2.0)",
   "language": "java",
   "name": "rapaio-jupyter-kernel"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "java",
   "nbconvert_exporter": "script",
   "pygments_lexer": "java",
   "version": "22.0.2+9-70"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
